Merge pull request #944 from CodingCat/scala_examples

Scala examples
This commit is contained in:
Tianqi Chen 2016-03-09 10:08:07 -08:00
commit db7a4e2ada
14 changed files with 663 additions and 13 deletions

View File

@ -11,6 +11,17 @@ XGBoost4J Code Examples
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
## Scala API
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala)
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala)
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala)
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/scala/example/PredictFirstNtree.scala)
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala)
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala)
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala)
* [External Memory](src/main/java/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala)
## Spark API
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala)

View File

@ -0,0 +1,98 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import java.io.File
import java.util
import scala.collection.mutable
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.java.example.util.DataLoader
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
/**
 * Basic walk-through of the Scala wrapper: train a model, predict, save and reload the model
 * and the test matrix, then build a DMatrix from CSR data and train again.
 *
 * Declared as an `object` (rather than a `class` with an instance `main`) so that the JVM
 * sees a static `main` entry point and the example can actually be launched.
 */
object BasicWalkThrough {

  def main(args: Array[String]): Unit = {
    val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMax = new DMatrix("../../demo/data/agaricus.txt.test")

    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 2,
      "silent" -> 1,
      "objective" -> "binary:logistic")
    val watches = Map[String, DMatrix]("train" -> trainMax, "test" -> testMax)
    val round = 2

    // train a model
    val booster = XGBoost.train(params, trainMax, round, watches)
    // predict
    val predicts = booster.predict(testMax)

    // save model to model path
    val file = new File("./model")
    if (!file.exists()) {
      file.mkdirs()
    }
    val modelDir = file.getAbsolutePath
    booster.saveModel(modelDir + "/xgb.model")
    // dump model
    // NOTE(review): the result of getModelDump is discarded here and below, and the meaning
    // of its first argument is not visible from this file -- confirm against the Booster API.
    booster.getModelDump(modelDir + "/dump.raw.txt", false)
    // dump model with feature map
    booster.getModelDump(modelDir + "/featmap.txt", false)
    // save dmatrix into binary buffer
    testMax.saveBinary(modelDir + "/dtest.buffer")

    // reload model and data
    val booster2 = XGBoost.loadModel(modelDir + "/xgb.model")
    val testMax2 = new DMatrix(modelDir + "/dtest.buffer")
    val predicts2 = booster2.predict(testMax2)

    // check predicts
    println(checkPredicts(predicts, predicts2))

    // build dmatrix from CSR Sparse Matrix
    println("start build dmatrix from csr sparse data ...")
    val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
    val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
      JDMatrix.SparseType.CSR)
    trainMax2.setLabel(spData.labels)

    // specify watchList
    val watches2 = Map[String, DMatrix]("train" -> trainMax2, "test" -> testMax2)
    val booster3 = XGBoost.train(params, trainMax2, round, watches2, null, null)
    val predicts3 = booster3.predict(testMax2)
    println(checkPredicts(predicts, predicts3))
  }

  /**
   * Compare two prediction matrices row by row.
   *
   * @param fPredicts first prediction matrix
   * @param sPredicts second prediction matrix, must have as many rows as `fPredicts`
   * @return true when every pair of corresponding rows is equal per `java.util.Arrays.equals`
   */
  def checkPredicts(fPredicts: Array[Array[Float]], sPredicts: Array[Array[Float]]): Boolean = {
    require(fPredicts.length == sPredicts.length, "the comparing predicts must be with the same " +
      "length")
    // forall short-circuits on the first mismatch without a non-local `return`
    fPredicts.indices.forall(i => java.util.Arrays.equals(fPredicts(i), sPredicts(i)))
  }
}

View File

@ -0,0 +1,53 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
/**
 * Example of continuing training from an initial prediction: the margin predictions of a
 * first model are installed as the base margin of the train/test matrices before a second
 * training run.
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object BoostFromPrediction {

  def main(args: Array[String]): Unit = {
    println("start running example to start from a initial prediction")

    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")

    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 2,
      "silent" -> 1,
      "objective" -> "binary:logistic")
    val watches = Map[String, DMatrix]("train" -> trainMat, "test" -> testMat)
    val round = 2

    // train a model
    val booster = XGBoost.train(params, trainMat, round, watches)

    // second argument `true` is passed through to predict; presumably it requests the
    // untransformed margin, which is what setBaseMargin expects -- confirm against Booster
    val trainPred = booster.predict(trainMat, true)
    val testPred = booster.predict(testMat, true)
    trainMat.setBaseMargin(trainPred)
    testMat.setBaseMargin(testPred)

    println("result of running from initial prediction")
    // the returned booster is not needed; the run is performed for its evaluation output
    XGBoost.train(params, trainMat, 1, watches, null, null)
  }
}

View File

@ -0,0 +1,46 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
/**
 * Example of running 5-fold cross validation on the agaricus training data.
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object CrossValidation {

  def main(args: Array[String]): Unit = {
    val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")

    // set params
    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 3,
      "silent" -> 1,
      "nthread" -> 6,
      "objective" -> "binary:logistic",
      "gamma" -> 1.0,
      "eval_metric" -> "error")

    // do 5-fold cross validation
    val round: Int = 2
    val nfold: Int = 5
    // no additional eval_metrics; null is passed through to the API unchanged
    val metrics: Array[String] = null

    val evalHist: Array[String] =
      XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null, null)
    // show the collected evaluation history instead of silently discarding it
    evalHist.foreach(println)
  }
}

View File

@ -0,0 +1,157 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix, EvalTrait, ObjectiveTrait}
import org.apache.commons.logging.{LogFactory, Log}
/**
 * An example of a user-defined objective and evaluation metric.
 *
 * NOTE: when you use a customized loss function, the default prediction value is the margin.
 * This may make the built-in evaluation metrics not function properly. For example, when doing
 * logistic loss the prediction is the score before the logistic transformation, while the
 * built-in evaluation error assumes its input is after the logistic transformation.
 * Keep this in mind when you use the customization; you may need to write a customized
 * evaluation function.
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object CustomObjective {

  /**
   * log-likelihood loss objective function
   */
  class LogRegObj extends ObjectiveTrait {
    private val logger: Log = LogFactory.getLog(classOf[LogRegObj])

    /**
     * user-defined objective function, returns gradient and second order gradient
     *
     * @param predicts untransformed margin predicts
     * @param dtrain training data
     * @return List with two float arrays, corresponding to first order grad and second order
     *         grad
     * @throws XGBoostError when the labels cannot be read from `dtrain`; the error is logged
     *                      and rethrown rather than swallowed
     */
    override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix)
        : List[Array[Float]] = {
      val labels: Array[Float] =
        try {
          dtrain.getLabel
        } catch {
          case e: XGBoostError =>
            // without labels no gradient can be computed, so surface the failure to the caller
            logger.error(e)
            throw e
        }
      val nrow = predicts.length
      val grad = new Array[Float](nrow)
      val hess = new Array[Float](nrow)
      val transPredicts = transform(predicts)
      for (i <- 0 until nrow) {
        val predict = transPredicts(i)(0)
        grad(i) = predict - labels(i)
        hess(i) = predict * (1 - predict)
      }
      List(grad, hess)
    }

    /**
     * simple sigmoid func
     *
     * Note: this func does not care about numerical stability, it is only used as an example
     *
     * @param input value to squash
     * @return 1 / (1 + exp(-input))
     */
    def sigmoid(input: Float): Float = {
      (1 / (1 + Math.exp(-input))).toFloat
    }

    /**
     * apply [[sigmoid]] to the first column of every prediction row
     *
     * @param predicts untransformed predictions, one row per instance
     * @return a new `nrow x 1` matrix holding the transformed values
     */
    def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = {
      predicts.map(row => Array(sigmoid(row(0))))
    }
  }

  /**
   * classification error computed directly on margin predictions (decision threshold 0)
   */
  class EvalError extends EvalTrait {

    val logger = LogFactory.getLog(classOf[EvalError])

    private[xgboost4j] var evalMetric: String = "custom_error"

    /**
     * get evaluate metric
     *
     * @return evalMetric
     */
    override def getMetric: String = evalMetric

    /**
     * evaluate with predicts and data
     *
     * @param predicts predictions as array
     * @param dmat data matrix to evaluate
     * @return result of the metric, or -1 when the labels cannot be read from `dmat`
     */
    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
      val labels: Array[Float] =
        try {
          dmat.getLabel
        } catch {
          case ex: XGBoostError =>
            logger.error(ex)
            return -1f
        }
      var error: Float = 0f
      for (i <- predicts.indices) {
        if (labels(i) == 0.0 && predicts(i)(0) > 0) {
          error += 1
        } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
          error += 1
        }
      }
      error / labels.length
    }
  }

  def main(args: Array[String]): Unit = {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")

    // no "objective" entry here: the second run below supplies the objective as LogRegObj
    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 2,
      "silent" -> 1)
    val watches = Map[String, DMatrix]("train" -> trainMat, "test" -> testMat)
    val round = 2

    // train a model
    // NOTE(review): this first run uses neither the custom objective nor the custom metric
    // and its result is unused -- confirm whether it is intended as a baseline.
    XGBoost.train(params, trainMat, round, watches)
    // train again with the user-defined objective and evaluation metric
    XGBoost.train(params, trainMat, round, watches, new LogRegObj, new EvalError)
  }
}

View File

@ -0,0 +1,59 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
/**
 * Example of loading the training/test data through the external-memory path by appending
 * `#<cache prefix>` to the data file name.
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object ExternalMemory {

  def main(args: Array[String]): Unit = {
    // this is the only difference, add a # followed by a cache prefix name
    // several cache file with the prefix will be generated
    // currently only support convert from libsvm file
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")

    // performance notice: set nthread to be the number of your real cpu
    // some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
    // set nthread=4 by adding:  "nthread" -> numRealCpu
    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 2,
      "silent" -> 1,
      "objective" -> "binary:logistic")
    val watches = Map[String, DMatrix]("train" -> trainMat, "test" -> testMat)
    val round = 2

    // train a model
    val booster = XGBoost.train(params, trainMat, round, watches)

    // NOTE(review): the remainder mirrors the boost-from-prediction example (base margin +
    // second training run); confirm it is intentionally part of the external-memory demo.
    val trainPred = booster.predict(trainMat, true)
    val testPred = booster.predict(testMat, true)
    trainMat.setBaseMargin(trainPred)
    testMat.setBaseMargin(testPred)

    println("result of running from initial prediction")
    XGBoost.train(params, trainMat, 1, watches, null, null)
  }
}

View File

@ -0,0 +1,60 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
/**
 * this is an example of fitting a generalized linear model in xgboost
 * basically, we are using a linear model, instead of trees, for our boosters
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object GeneralizedLinearModel {

  def main(args: Array[String]): Unit = {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")

    // specify parameters
    // change booster to gblinear, so that we are fitting a linear model
    // alpha is the L1 regularizer
    // lambda is the L2 regularizer
    // you can also set lambda_bias which is L2 regularizer on the bias term
    val params = Map[String, Any](
      "alpha" -> 0.0001,
      // the key must be exactly "booster"; a misspelled key would leave the booster unchanged
      "booster" -> "gblinear",
      "silent" -> 1,
      "objective" -> "binary:logistic")

    // normally, you do not need to set eta (step_size)
    // XGBoost uses a parallel coordinate descent algorithm (shotgun),
    // there could be affection on convergence with parallelization on certain cases
    // setting eta to be smaller value, e.g 0.5 can make the optimization more stable
    // to do so add:  "eta" -> 0.5

    val watches = Map[String, DMatrix]("train" -> trainMat, "test" -> testMat)

    // number of boosting rounds actually passed to train (previously declared but ignored)
    val round = 4
    val booster = XGBoost.train(params, trainMat, round, watches, null, null)

    val predicts = booster.predict(testMat)
    val eval = new CustomEval
    println(s"error=${eval.eval(predicts, testMat)}")
  }
}

View File

@ -0,0 +1,53 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
/**
 * Example of predicting with only the first tree(s) of a trained model and comparing the
 * error against predictions made with all trees.
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object PredictFirstNTree {

  def main(args: Array[String]): Unit = {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")

    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 2,
      "silent" -> 1,
      "objective" -> "binary:logistic")
    val watches = Map[String, DMatrix]("train" -> trainMat, "test" -> testMat)
    val round = 3

    // train a model
    val booster = XGBoost.train(params, trainMat, round, watches)

    // predict use 1 tree
    val predicts1 = booster.predict(testMat, false, 1)
    // by default all trees are used to do predict
    val predicts2 = booster.predict(testMat)

    val eval = new CustomEval
    println("error of predicts1: " + eval.eval(predicts1, testMat))
    println("error of predicts2: " + eval.eval(predicts2, testMat))
  }
}

View File

@ -0,0 +1,56 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example
import java.util
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
/**
 * Example of predicting leaf indices, first with the first two trees and then with all trees.
 *
 * Declared as an `object` so that `main` is a real (static) JVM entry point.
 */
object PredictLeafIndices {

  def main(args: Array[String]): Unit = {
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")

    val params = Map[String, Any](
      "eta" -> 1.0,
      "max_depth" -> 2,
      "silent" -> 1,
      "objective" -> "binary:logistic")
    val watches = Map[String, DMatrix]("train" -> trainMat, "test" -> testMat)
    val round = 3

    val booster = XGBoost.train(params, trainMat, round, watches)

    // predict using first 2 tree
    val leafIndex = booster.predictLeaf(testMat, 2)
    for (leafs <- leafIndex) {
      println(java.util.Arrays.toString(leafs))
    }

    // predict all trees
    val leafIndex2 = booster.predictLeaf(testMat, 0)
    // print the all-trees result (the loop previously re-printed the 2-tree result)
    for (leafs <- leafIndex2) {
      println(java.util.Arrays.toString(leafs))
    }
  }
}

View File

@ -13,11 +13,10 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink.example
package ml.dmlc.xgboost4j.scala.example.flink
import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.api.scala.{ExecutionEnvironment, _}
import org.apache.flink.ml.MLUtils
object DistTrainWithFlink {

View File

@ -14,11 +14,11 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.example
package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSpark {
def main(args: Array[String]): Unit = {

View File

@ -0,0 +1,60 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example.util
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.{Log, LogFactory}
class CustomEval extends EvalTrait {

  private val logger: Log = LogFactory.getLog(classOf[CustomEval])

  /**
   * name under which this metric is reported
   *
   * @return the metric name
   */
  override def getMetric: String = "custom_error"

  /**
   * classification error of `predicts` against the labels stored in `dmat`, using 0.5 as the
   * decision threshold on the first column of each prediction row
   *
   * @param predicts predictions as array
   * @param dmat data matrix to evaluate
   * @return fraction of misclassified rows, or -1 when the labels cannot be read
   */
  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    val truth: Array[Float] =
      try {
        dmat.getLabel
      } catch {
        case ex: XGBoostError =>
          logger.error(ex)
          return -1f
      }
    // accumulate in Float, one unit per misclassified row
    val misses = predicts.indices.foldLeft(0f) { (acc, row) =>
      val falsePositive = truth(row) == 0.0 && predicts(row)(0) > 0.5
      val wrong = falsePositive || (truth(row) == 1.0 && predicts(row)(0) <= 0.5)
      if (wrong) acc + 1 else acc
    }
    misses / truth.length
  }
}

View File

@ -107,7 +107,6 @@ object XGBoost extends Serializable {
/**
 * Load a model from an HDFS-compatible file system.
 *
 * The input stream is closed even when loading the model fails.
 *
 * @param modelPath The model path as in Hadoop path.
 * @param sparkContext supplies the Hadoop configuration used to resolve the file system
 * @return the loaded model wrapped as an XGBoostModel
 */
def loadModelFromHadoop(modelPath: String)(implicit sparkContext: SparkContext): XGBoostModel = {
  val dataInStream = FileSystem.get(sparkContext.hadoopConfiguration).open(new Path(modelPath))
  try {
    new XGBoostModel(SXGBoost.loadModel(dataInStream))
  } finally {
    // release the stream on both the success and the failure path
    dataInStream.close()
  }
}
}

View File

@ -26,12 +26,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
/**
* Predict result with the given testset (represented as RDD)
*/
* Predict result with the given testset (represented as RDD)
*/
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(booster)
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
testSet.mapPartitions { testSamples =>
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predict(dMatrix))
@ -46,10 +45,10 @@ class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Seri
}
/**
* Save the model as to HDFS-compatible file system.
*
* @param modelPath The model path as in Hadoop path.
*/
* Save the model as to HDFS-compatible file system.
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModelToHadoop(modelPath: String): Unit = {
val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
booster.saveModel(outputStream)