fix examples

This commit is contained in:
CodingCat
2016-03-11 13:57:03 -05:00
parent aca0096b33
commit ab68a0ccc7
5 changed files with 29 additions and 14 deletions

View File

@@ -81,13 +81,14 @@ object XGBoost {
/**
* Train a xgboost model with link.
*
* @param params The parameters to XGBoost.
* @param dtrain The training data.
* @param params The parameters to XGBoost.
* @param round Number of rounds to train.
*/
def train(params: Map[String, Any],
dtrain: DataSet[LabeledVector],
round: Int): XGBoostModel = {
def train(
dtrain: DataSet[LabeledVector],
params: Map[String, Any],
round: Int): XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
if (tracker.start()) {
dtrain

View File

@@ -37,6 +37,15 @@ class XGBoostModel (booster: Booster) extends Serializable {
.create(new Path(modelPath)))
}
/**
* predict with the given DMatrix
* @param testSet the local test set represented as DMatrix
* @return prediction result
*/
def predict(testSet: DMatrix): Array[Array[Float]] = {
booster.predict(testSet, true, 0)
}
/**
* Predict given vector dataset.
*
@@ -44,7 +53,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
* @return The prediction result.
*/
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
(it: Iterator[Vector]) => {
val mapper = (x: Vector) => {
val (index, value) = x.toSeq.unzip