From 8cfa752fa08e4b66409113ba67da3a92b64fd111 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 8 Mar 2016 19:45:33 -0500 Subject: [PATCH] add scala examples --- .../scala/example/BasicWalkThrough.scala | 98 +++++++++++ .../scala/example/BoostFromPrediction.scala | 53 ++++++ .../scala/example/CrossValidation.scala | 46 +++++ .../scala/example/CustomObjective.scala | 157 ++++++++++++++++++ .../scala/example/ExternalMemory.scala | 59 +++++++ .../example/GeneralizedLinearModel.scala | 60 +++++++ .../scala/example/PredictFirstNTree.scala | 53 ++++++ .../scala/example/PredictLeafIndices.scala | 56 +++++++ .../scala/example/util/CustomEval.scala | 60 +++++++ 9 files changed, 642 insertions(+) create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala create mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/util/CustomEval.scala diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala new file mode 100644 index 000000000..fdfb50c94 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala @@ -0,0 +1,98 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.example + +import java.io.File +import java.util + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} +import ml.dmlc.xgboost4j.java.example.util.DataLoader +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} + +class BasicWalkThrough { + def main(args: Array[String]): Unit = { + import BasicWalkThrough._ + val trainMax = new DMatrix("../../demo/data/agaricus.txt.train") + val testMax = new DMatrix("../../demo/data/agaricus.txt.test") + + val params = new mutable.HashMap[String, Any]() + params += "eta" -> 1.0 + params += "max_depth" -> 2 + params += "silent" -> 1 + params += "objective" -> "binary:logistic" + + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMax + watches += "test" -> testMax + + val round = 2 + // train a model + val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap) + // predict + val predicts = booster.predict(testMax) + // save model to model path + val file = new File("./model") + if (!file.exists()) { + file.mkdirs() + } + booster.saveModel(file.getAbsolutePath + "/xgb.model") + // dump model + booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false) + // dump model with feature map + booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false) + // save dmatrix into binary buffer + testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer") + + // reload model and data + val booster2 = XGBoost.loadModel(file.getAbsolutePath + "/xgb.model") + val testMax2 = new DMatrix(file.getAbsolutePath + "/dtest.buffer") + val predicts2 = booster2.predict(testMax2) + + // check predicts + println(checkPredicts(predicts, predicts2)) + + // build dmatrix from CSR Sparse Matrix + println("start build dmatrix from csr sparse data ...") + val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train") + val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, + JDMatrix.SparseType.CSR) + trainMax2.setLabel(spData.labels) + + // specify watchList + val watches2 = new mutable.HashMap[String, DMatrix] + watches2 += "train" -> trainMax2 + watches2 += "test" -> testMax2 + val booster3 = XGBoost.train(params.toMap, trainMax2, round, watches2.toMap, null, null) + val predicts3 = booster3.predict(testMax2) + println(checkPredicts(predicts, predicts3)) + } +} + +object BasicWalkThrough { + def checkPredicts(fPredicts: Array[Array[Float]], sPredicts: Array[Array[Float]]): Boolean = { + require(fPredicts.length == sPredicts.length, "the comparing predicts must be with the same " + + "length") + for (i <- fPredicts.indices) { + if (!java.util.Arrays.equals(fPredicts(i), sPredicts(i))) { + return false + } + } + true + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala new file mode 100644 index 000000000..a68d479c1 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala @@ -0,0 +1,53 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.example + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} + +class BoostFromPrediction { + def main(args: Array[String]): Unit = { + println("start running example to start from a initial prediction") + + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + + val params = new mutable.HashMap[String, Any]() + params += "eta" -> 1.0 + params += "max_depth" -> 2 + params += "silent" -> 1 + params += "objective" -> "binary:logistic" + + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMat + watches += "test" -> testMat + + val round = 2 + // train a model + val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + + val trainPred = booster.predict(trainMat, true) + val testPred = booster.predict(testMat, true) + + trainMat.setBaseMargin(trainPred) + testMat.setBaseMargin(testPred) + + System.out.println("result of running from initial prediction") + val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala new file mode 100644 index 000000000..493aa2e62 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala @@ -0,0 +1,46 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.scala.example + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} + +class CrossValidation { + def main(args: Array[String]): Unit = { + val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train") + + // set params + val params = new mutable.HashMap[String, Any] + + params.put("eta", 1.0) + params.put("max_depth", 3) + params.put("silent", 1) + params.put("nthread", 6) + params.put("objective", "binary:logistic") + params.put("gamma", 1.0) + params.put("eval_metric", "error") + + // do 5-fold cross validation + val round: Int = 2 + val nfold: Int = 5 + // set additional eval_metrics + val metrics: Array[String] = null + + val evalHist: Array[String] = + XGBoost.crossValidation(params.toMap, trainMat, round, nfold, metrics, null, null) + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala new file mode 100644 index 000000000..3f27e9031 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala @@ -0,0 +1,157 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.scala.example + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix, EvalTrait, ObjectiveTrait} +import org.apache.commons.logging.{LogFactory, Log} + +/** + * an example user define objective and eval + * NOTE: when you do customized loss function, the default prediction value is margin + * this may make buildin evalution metric not function properly + * for example, we are doing logistic loss, the prediction is score before logistic transformation + * he buildin evaluation error assumes input is after logistic transformation + * Take this in mind when you use the customization, and maybe you need write customized evaluation + * function + * + */ +class CustomObjective { + + /** + * loglikelihoode loss obj function + */ + class LogRegObj extends ObjectiveTrait { + private val logger: Log = LogFactory.getLog(classOf[LogRegObj]) + /** + * user define objective function, return gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix) + : List[Array[Float]] = { + val nrow = predicts.length + val gradients = new ListBuffer[Array[Float]] + var labels: Array[Float] = null + try { + labels = dtrain.getLabel + } catch { + case e: XGBoostError => + logger.error(e) + null + case _ => + null + } + val grad = new Array[Float](nrow) + val hess = new Array[Float](nrow) + val transPredicts = transform(predicts) + + for (i <- 0 until nrow) { + val predict = transPredicts(i)(0) + grad(i) = predict - labels(i) + hess(i) = predict * (1 - predict) + } + gradients += grad + gradients += hess + gradients.toList + } + + /** + * simple sigmoid func + * + * @param input + * @return Note: this func is not concern about numerical stability, only used as example + */ + def sigmoid(input: Float): Float = { + (1 / (1 + Math.exp(-input))).toFloat + } + + def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = { + val nrow = predicts.length + val transPredicts = Array.fill[Float](nrow, 1)(0) + for (i <- 0 until nrow) { + transPredicts(i)(0) = sigmoid(predicts(i)(0)) + } + transPredicts + } + + } + + class EvalError extends EvalTrait { + + val logger = LogFactory.getLog(classOf[EvalError]) + + private[xgboost4j] var evalMetric: String = "custom_error" + + /** + * get evaluate metric + * + * @return evalMetric + */ + override def getMetric: String = evalMetric + + /** + * evaluate with predicts and data + * + * @param predicts predictions as array + * @param dmat data matrix to evaluate + * @return result of the metric + */ + override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = { + var error: Float = 0f + var labels: Array[Float] = null + try { + labels = dmat.getLabel + } catch { + case ex: XGBoostError => + logger.error(ex) + return -1f + } + val nrow: Int = predicts.length + for (i <- 0 until nrow) { + if (labels(i) == 0.0 && predicts(i)(0) > 0) { + error += 1 + } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) { + error += 1 + } + } + error / labels.length + } + } + + def main(args: Array[String]): Unit = { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + val params = new mutable.HashMap[String, Any]() + params += "eta" -> 1.0 + params += "max_depth" -> 2 + params += "silent" -> 1 + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMat + watches += "test" -> testMat + + val round = 2 + // train a model + val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + XGBoost.train(params.toMap, trainMat, round, watches.toMap, new LogRegObj, new EvalError) + } + +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala new file mode 100644 index 000000000..61faf3293 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala @@ -0,0 +1,59 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.example + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} + +class ExternalMemory { + def main(args: Array[String]): Unit = { + // this is the only difference, add a # followed by a cache prefix name + // several cache file with the prefix will be generated + // currently only support convert from libsvm file + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache") + + val params = new mutable.HashMap[String, Any]() + params += "eta" -> 1.0 + params += "max_depth" -> 2 + params += "silent" -> 1 + params += "objective" -> "binary:logistic" + + // performance notice: set nthread to be the number of your real cpu + // some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case + // set nthread=4 + // param.put("nthread", num_real_cpu); + + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMat + watches += "test" -> testMat + + val round = 2 + // train a model + val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + + val trainPred = booster.predict(trainMat, true) + val testPred = booster.predict(testMat, true) + + trainMat.setBaseMargin(trainPred) + testMat.setBaseMargin(testPred) + + System.out.println("result of running from initial prediction") + val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala new file mode 100644 index 000000000..580f8351a --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala @@ -0,0 +1,60 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.scala.example + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} +import ml.dmlc.xgboost4j.scala.example.util.CustomEval + + +/** + * this is an example of fit generalized linear model in xgboost + * basically, we are using linear model, instead of tree for our boosters + */ +class GeneralizedLinearModel { + def main(args: Array[String]): Unit = { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + + // specify parameters + // change booster to gblinear, so that we are fitting a linear model + // alpha is the L1 regularizer + // lambda is the L2 regularizer + // you can also set lambda_bias which is L2 regularizer on the bias term + val params = new mutable.HashMap[String, Any]() + params += "alpha" -> 0.0001 + params += "boosterh" -> "gblinear" + params += "silent" -> 1 + params += "objective" -> "binary:logistic" + + // normally, you do not need to set eta (step_size) + // XGBoost uses a parallel coordinate descent algorithm (shotgun), + // there could be affection on convergence with parallelization on certain cases + // setting eta to be smaller value, e.g 0.5 can make the optimization more stable + // param.put("eta", "0.5"); + + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMat + watches += "test" -> testMat + + val round = 4 + val booster = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null) + val predicts = booster.predict(testMat) + val eval = new CustomEval + println(s"error=${eval.eval(predicts, testMat)}") + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala new file mode 100644 index 000000000..8dd83e6c7 --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala @@ -0,0 +1,53 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.scala.example + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.example.util.CustomEval +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} + +class PredictFirstNTree { + + def main(args: Array[String]): Unit = { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + + val params = new mutable.HashMap[String, Any]() + params += "eta" -> 1.0 + params += "max_depth" -> 2 + params += "silent" -> 1 + params += "objective" -> "binary:logistic" + + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMat + watches += "test" -> testMat + + val round = 3 + // train a model + val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + + // predict use 1 tree + val predicts1 = booster.predict(testMat, false, 1) + // by default all trees are used to do predict + val predicts2 = booster.predict(testMat) + + val eval = new CustomEval + println("error of predicts1: " + eval.eval(predicts1, testMat)) + println("error of predicts2: " + eval.eval(predicts2, testMat)) + } + +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala new file mode 100644 index 000000000..d7194c73f --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala @@ -0,0 +1,56 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.example + +import java.util + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix} + +class PredictLeafIndices { + + def main(args: Array[String]): Unit = { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + + val params = new mutable.HashMap[String, Any]() + params += "eta" -> 1.0 + params += "max_depth" -> 2 + params += "silent" -> 1 + params += "objective" -> "binary:logistic" + + val watches = new mutable.HashMap[String, DMatrix] + watches += "train" -> trainMat + watches += "test" -> testMat + + val round = 3 + val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap) + + // predict using first 2 tree + val leafIndex = booster.predictLeaf(testMat, 2) + for (leafs <- leafIndex) { + println(java.util.Arrays.toString(leafs)) + } + + // predict all trees + val leafIndex2 = booster.predictLeaf(testMat, 0) + for (leafs <- leafIndex) { + println(java.util.Arrays.toString(leafs)) + } + } +} diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/util/CustomEval.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/util/CustomEval.scala new file mode 100644 index 000000000..6fb233c2a --- /dev/null +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/util/CustomEval.scala @@ -0,0 +1,60 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.scala.example.util + +import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} +import org.apache.commons.logging.{Log, LogFactory} + +class CustomEval extends EvalTrait { + private val logger: Log = LogFactory.getLog(classOf[CustomEval]) + /** + * get evaluate metric + * + * @return evalMetric + */ + override def getMetric: String = { + "custom_error" + } + + /** + * evaluate with predicts and data + * + * @param predicts predictions as array + * @param dmat data matrix to evaluate + * @return result of the metric + */ + override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = { + var error: Float = 0f + var labels: Array[Float] = null + try { + labels = dmat.getLabel + } catch { + case ex: XGBoostError => + logger.error(ex) + return -1f + } + val nrow: Int = predicts.length + for (i <- 0 until nrow) { + if (labels(i) == 0.0 && predicts(i)(0) > 0.5) { + error += 1 + } else if (labels(i) == 1.0 && predicts(i)(0) <= 0.5) { + error += 1 + } + } + error / labels.length + } +}